# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from tqdm import tqdm
from collections import Counter
from typing import Dict, List, Optional

import paddle
from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM
from ...core import MMDataset, register


def load_model(model_name: str):
    """
    Instantiate the tokenizer and the causal language model for ``model_name``.

    Args:
        model_name (str): Identifier passed to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(tokenizer, model)``; the model is requested with
        ``dtype="float16"``.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is still
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForCausalLM.from_pretrained(model_name, dtype="float16"),
    )


def parse_model_output(output_text: str) -> Optional[dict]:
    """
    Extract the attribute JSON object from raw model output.

    The text is scanned for the first ``{`` that opens a syntactically valid
    JSON object. Unlike slicing from the first ``{`` to the last ``}``, this
    still succeeds when the model emits extra text containing braces before
    or after the object, or repeats the object several times.

    Args:
        output_text (str): Decoded text generated by the model.

    Returns:
        Optional[dict]: The parsed object with every required attribute key
        present (missing keys are filled with ``"N/A"`` or ``["N/A"]``), or
        ``None`` when no JSON object can be found.
    """
    required_keys = ['colors', 'shapes', 'position', 'size', 'direction',
                     'relationships', 'actions', 'categories']
    # Attributes whose placeholder is a list; the rest default to a string.
    list_valued_keys = ('colors', 'shapes', 'relationships',
                        'actions', 'categories')

    first_brace = output_text.find('{')
    last_brace = output_text.rfind('}')
    if first_brace == -1 or last_brace == -1:
        # Nothing that could be an object at all.
        return None

    decoder = json.JSONDecoder()
    parsed_data = None
    position = first_brace
    # An object must close at or before the last '}', so candidates that
    # start after it cannot succeed.
    while position != -1 and position < last_brace:
        try:
            candidate, _ = decoder.raw_decode(output_text, position)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            parsed_data = candidate
            break
        position = output_text.find('{', position + 1)

    if parsed_data is None:
        print(f"Failed to parse output: {output_text}")
        return None

    # Add default values for missing keys
    for key in required_keys:
        if key not in parsed_data:
            parsed_data[key] = ["N/A"] if key in list_valued_keys else "N/A"

    return parsed_data


def clean_and_count(all_info: List[dict]) -> Dict[str, Counter]:
    """
    Count how often each attribute value occurs, per attribute category.

    Values are normalized by stripping surrounding whitespace and
    lower-casing. Entries that are not strings (for example JSON ``null`` or
    numbers emitted by the model), empty strings, and the ``"N/A"``
    placeholder in any casing are ignored instead of raising or being counted.

    Args:
        all_info (List[dict]): Parsed attribute dictionaries; falsy or
            non-dict entries are skipped.

    Returns:
        Dict[str, Counter]: Maps each display name (``'Colors'``,
        ``'Shapes'``, ...) to a counter of normalized values.
    """
    # Display name -> key used in the parsed model output.
    categories = {
        'Colors': 'colors',
        'Shapes': 'shapes',
        'Position': 'position',
        'Size': 'size',
        'Direction': 'direction',
        'Relationships': 'relationships',
        'Actions': 'actions',
        'Categories': 'categories',
    }

    cleaned_info = {category: Counter() for category in categories}

    for item in all_info:
        if not item or not isinstance(item, dict):  # Skip invalid data
            continue

        for category, key in categories.items():
            value = item.get(key, "N/A")
            # Treat a scalar as a one-element list so both shapes share one path.
            entries = value if isinstance(value, list) else [value]
            for entry in entries:
                if not isinstance(entry, str):
                    # Model output is untrusted: null/number/nested values
                    # have no .strip() and carry no attribute text.
                    continue
                normalized = entry.strip().lower()
                if normalized and normalized != "n/a":
                    cleaned_info[category][normalized] += 1

    return cleaned_info


# Prompt template sent to the language model for attribute extraction.
# The `{text_input}` placeholder is filled in with `str.replace` by
# `description_analysis` rather than `str.format`, since the literal braces
# of the JSON example below are not escaped for `str.format`.
# NOTE(review): the example uses "Q:"/"A:" prefixes whereas
# `description_analysis` formats turns as "Question:"/"Answer:" — confirm
# whether the mismatch is intentional.
ANALYSIS_PROMPT = '''
Extract attributes from the following conversation about an image. Output only a JSON object with these attributes:
- colors: [array of colors mentioned]
- shapes: [array of shapes mentioned]
- position: string describing spatial position
- size: string describing size
- direction: string describing direction
- relationships: [array of relationships between objects]
- actions: [array of actions mentioned]
- categories: [array of objects/items mentioned]

Use "N/A" for any attribute not mentioned in the conversation.

Example Input:
Q: What is in the image?
A: A man is standing next to a car.
Q: What is the color of the car?
A: The car is red.

Example Output:
{
  "colors": ["red"],
  "shapes": ["N/A"],
  "position": "next to the man",
  "size": "N/A",
  "direction": "N/A",
  "relationships": ["man next to car"],
  "actions": ["standing"],
  "categories": ["man", "car"]
}

Conversation to analyze:
{text_input}
'''


@register()
def description_analysis(dataset: MMDataset,
    model_name: str = "Qwen/Qwen2.5-7B",
    batch_size: int = 1) -> MMDataset:
    """
    Analyze all conversations in the dataset, extract attributes, and compute statistics.

    For every dataset item the Q&A pairs are merged into one text, inserted
    into ``ANALYSIS_PROMPT`` and sent to the language model. The generated
    text is parsed with ``parse_model_output``; attribute frequencies of all
    successfully parsed outputs are printed to stdout.

    Args:
        dataset (MMDataset): Dataset object containing image paths and conversations.
        model_name (str): Name of the model to use.
        batch_size (int): Batch size for processing. Must be at least 1.

    Returns:
        MMDataset: The items whose model output could be parsed, each with
        its own ``image`` and ``conversations``. Items are keyed by image
        path, so only the first parsed item per image path is kept.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Fail before the (expensive) model load; a non-positive batch size would
    # otherwise divide by zero or silently process nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Load the model and tokenizer
    tokenizer, model = load_model(model_name)
    model.eval()

    # Store all parsed results
    all_parsed_results = []
    filtered_data = {}

    # Collect data to process
    all_data = []
    print("Collecting data...")
    for item in dataset:
        image_path = item.get("image", "")
        conversations = item.get("conversations", [])

        # Combine all Q&A pairs into a single conversation
        caption_parts = []
        for question, answer in conversations:
            # Clean the question to remove <image> tags
            question = question.replace('<image>\n', '').replace('\n<image>', '').replace('<image>', '')
            caption_parts.append(f"Question: {question.strip()}\nAnswer: {answer.strip()}\n")
        full_caption = "".join(caption_parts)

        all_data.append({
            'image_path': image_path,
            # Keep each sample's own conversations so the result does not
            # depend on whatever the loop variable last pointed to.
            'conversations': conversations,
            # str.replace is used because the template holds literal braces.
            'prompt': ANALYSIS_PROMPT.replace("{text_input}", full_caption),
        })

    total_samples = len(all_data)
    num_batches = (total_samples + batch_size - 1) // batch_size
    print(f"Collected {total_samples} samples, split into {num_batches} batches for processing.")

    # Process data in batches
    for batch_idx in tqdm(range(num_batches), desc="Processing batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_samples)
        batch_data = all_data[start_idx:end_idx]
        batch_prompts = [sample['prompt'] for sample in batch_data]
        batch_image_paths = [sample['image_path'] for sample in batch_data]

        try:
            # Model inference
            input_features = tokenizer(batch_prompts, return_tensors="pd", padding=True)
            with paddle.no_grad():
                outputs = model.generate(**input_features, max_length=512)

            # Decode outputs
            # NOTE(review): presumably generate() returns a tuple whose first
            # element holds the generated token ids — confirm against PaddleNLP.
            decoded_outputs = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
            print(f"image_path:{batch_image_paths}")
            print(f"prompt:{batch_prompts}")
            print(f"Decoded Outputs for Batch {batch_idx}:\n{decoded_outputs}")

            # Process results for the current batch
            for idx, (current_item, analysis_result) in enumerate(zip(batch_data, decoded_outputs)):
                parsed_result = parse_model_output(analysis_result)
                print(f"Parsed Result for Input {idx} in Batch {batch_idx}: {parsed_result}")
                if not parsed_result:
                    continue
                all_parsed_results.append(parsed_result)

                # Update filtered dataset (first parsed item per image wins)
                image_path = current_item['image_path']
                if image_path not in filtered_data:
                    filtered_data[image_path] = {
                        "image": image_path,
                        "conversations": current_item['conversations']
                    }
            print("*" * 50)

        except Exception as e:
            # Best effort: a failing batch is reported and skipped so the
            # remaining batches are still processed.
            print(f"Error processing batch {batch_idx + 1}/{num_batches}:")
            print(f"Error details: {str(e)}")
            print("-" * 50)
            continue

    # Clean and count results
    cleaned_info = clean_and_count(all_parsed_results)

    # Output attribute statistics
    print("\n=== Attribute Statistics ===")
    for category, counts in cleaned_info.items():
        print(f"\n{category}:")
        for value, count in counts.most_common(10):  # Display the top 10 most common items for each category
            print(f"  {value}: {count}")

    # Output summary
    print("\n=== Processing Summary ===")
    print(f"Total Q&A pairs: {total_samples}")
    print(f"Successfully parsed: {len(all_parsed_results)}")
    print(f"Number of images involved: {len(filtered_data)}")

    return MMDataset(list(filtered_data.values()))